﻿#pragma kernel Clear
#pragma kernel Initialize
#pragma kernel Project
#pragma kernel Apply
#pragma kernel ProjectRenderable

#include "MathUtils.cginc"
#include "AtomicDeltas.cginc"
#include "ColliderDefinitions.cginc"
#include "Rigidbody.cginc"

// --- Per-constraint data (indexed by the constraint/thread index) ---
// Index of the particle affected by each constraint.
StructuredBuffer<int> particleIndices;
// Index of the collider each constraint pins to; a negative value means "no collider" and the constraint is skipped.
StructuredBuffer<int> colliderIndices;
// Pin point, transformed by the collider's transform in Project/ProjectRenderable.
// ProjectRenderable skips the constraint when the w component is below 0.5.
StructuredBuffer<float4> offsets;
// Per-constraint quaternion: compared against the relative particle/collider rotation in Project,
// and composed with the collider rotation in ProjectRenderable.
StructuredBuffer<quaternion> restDarboux;
// x and y are divided by substepTime^2 in Project to obtain the linear (x) and angular (y) compliances.
// ProjectRenderable additionally compares y against 10000.
StructuredBuffer<float2> stiffnesses;
// Accumulated multipliers, read and written back by Project: xyz = angular part, w = linear part.
RWStructuredBuffer<float4> lambdas;

// --- Collider data (indexed by collider index) ---
StructuredBuffer<transform> transforms;
// shapes[c].rigidbodyIndex is negative when the collider has no rigidbody.
StructuredBuffer<shape> shapes;
// constraintCount is reset to zero by Clear and incremented by Initialize.
RWStructuredBuffer<rigidbody> RW_rigidbodies;

// --- Writable particle data (written by Apply and ProjectRenderable) ---
RWStructuredBuffer<float4> RW_positions;
RWStructuredBuffer<quaternion> RW_orientations;

// --- Read-only particle data (indexed by particle index, read by Project) ---
StructuredBuffer<float4> positions;
StructuredBuffer<quaternion> orientations;
StructuredBuffer<float4> prevPositions;
StructuredBuffer<float> invMasses;
StructuredBuffer<float> invRotationalMasses;

// Only element 0 is read; its .frame converts between world space and solver space.
StructuredBuffer<inertialFrame> inertialSolverFrame;

// Variables set from the CPU
uint activeConstraintCount; // threads whose index is >= this value exit immediately.
float stepTime;             // multiplied by 'steps' in Project; also passed to ApplyDeltaQuaternion.
float substepTime;          // used to scale compliances and as divisor of timeLeft in Project.
float timeLeft;             // timeLeft / substepTime is the lerp factor and delta divisor in Project.
int steps;                  // multiplied by stepTime in Project.
float sorFactor;            // forwarded to ApplyPositionDelta / ApplyOrientationDelta in Apply.

// Clear: one thread per constraint. For every active constraint that references a collider
// with a rigidbody attached, resets that rigidbody's constraintCount to zero.
[numthreads(128, 1, 1)]
void Clear (uint3 id : SV_DispatchThreadID)
{
    uint constraintIdx = id.x;

    // threads past the active constraint range have nothing to do.
    if (constraintIdx < activeConstraintCount)
    {
        int colliderIdx = colliderIndices[constraintIdx];

        // a negative collider index means there is nothing to pin to.
        if (colliderIdx >= 0)
        {
            int rigidbodyIdx = shapes[colliderIdx].rigidbodyIndex;

            // colliders without a rigidbody (negative index) have no counter to reset.
            if (rigidbodyIdx >= 0)
            {
                // the previous value is required by the intrinsic's signature but is not used.
                int previousCount;
                InterlockedExchange(RW_rigidbodies[rigidbodyIdx].constraintCount, 0, previousCount);
            }
        }
    }
}

// Initialize: one thread per constraint. For every active constraint that references a collider
// with a rigidbody attached, atomically adds one to that rigidbody's constraintCount.
[numthreads(128, 1, 1)]
void Initialize (uint3 id : SV_DispatchThreadID)
{
    uint constraintIdx = id.x;

    // threads past the active constraint range have nothing to do.
    if (constraintIdx < activeConstraintCount)
    {
        int colliderIdx = colliderIndices[constraintIdx];

        // a negative collider index means there is nothing to pin to.
        if (colliderIdx >= 0)
        {
            int rigidbodyIdx = shapes[colliderIdx].rigidbodyIndex;

            // only colliders that have a rigidbody contribute to the count.
            if (rigidbodyIdx >= 0)
                InterlockedAdd(RW_rigidbodies[rigidbodyIdx].constraintCount, 1);
        }
    }
}

// Project: one thread per constraint. Computes a positional correction that moves the particle
// towards the pin point (offsets[i] expressed through the collider's transform) and, when either
// side has rotational mobility, an orientation correction relative to the collider's rotation.
// Particle corrections are accumulated through AddPositionDelta / orientationDeltasAsInt, and the
// opposite reaction is sent to the rigidbody (if any) through ApplyImpulse / ApplyDeltaQuaternion.
// Those helpers and buffers are declared outside this file.
// lambdas[i] stores the accumulated multipliers: w for the linear part, xyz for the angular part.
[numthreads(128, 1, 1)]
void Project (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;

    if (i >= activeConstraintCount) return;

    int particleIndex = particleIndices[i];
    int colliderIndex = colliderIndices[i];

    // no collider to pin to, so ignore the constraint.
    if (colliderIndex < 0)
        return;

    // negative when the collider has no rigidbody attached.
    int rigidbodyIndex = shapes[colliderIndex].rigidbodyIndex;

    // frameEnd: time span used below to integrate the rigidbody's pin point/rotation forward,
    // and to turn the positional correction into an impulse.
    // substepsToEnd: lerp factor for the particle position and divisor of the particle deltas.
    float frameEnd = stepTime * steps;
    float substepsToEnd = timeLeft / substepTime;

    // calculate time adjusted compliances
    // (x is used by the linear constraint, y by the angular one)
    float2 compliances = stiffnesses[i].xy / (substepTime * substepTime);

    // project particle position to the end of the full step:
    float4 particlePosition = lerp(prevPositions[particleIndex], positions[particleIndex], substepsToEnd);

    // express pin offset in world space:
    float4 worldPinOffset = transforms[colliderIndex].TransformPoint(offsets[i]);
    // without a rigidbody, the predicted pin point/rotation are simply the collider's current ones.
    float4 predictedPinOffset = worldPinOffset;
    quaternion predictedRotation = transforms[colliderIndex].rotation;

    // rigidbody contributions to the constraint denominators; stay at zero when there is no rigidbody.
    float rigidbodyLinearW = 0;
    float rigidbodyAngularW = 0;

    if (rigidbodyIndex >= 0)
    {
        rigidbody rb = rigidbodies[rigidbodyIndex];

        // predict offset point position using rb velocity at that point (can't integrate transform since position != center of mass)
        // NOTE(review): the call below re-reads rigidbodies[rigidbodyIndex] although 'rb' already holds a copy.
        float4 velocityAtPoint = GetRigidbodyVelocityAtPoint(rigidbodies[rigidbodyIndex],inertialSolverFrame[0].frame.InverseTransformPoint(worldPinOffset), 
                                                             asfloat(linearDeltasAsInt[rigidbodyIndex]), 
                                                             asfloat(angularDeltasAsInt[rigidbodyIndex]), inertialSolverFrame[0]);

        predictedPinOffset = IntegrateLinear(predictedPinOffset, inertialSolverFrame[0].frame.TransformVector(velocityAtPoint), frameEnd);

        // predict rotation at the end of the step:
        predictedRotation = IntegrateAngular(predictedRotation, rb.angularVelocity + asfloat(angularDeltasAsInt[rigidbodyIndex]), frameEnd);

        // calculate linear and angular rigidbody effective masses (mass splitting: multiply by constraint count)
        rigidbodyLinearW = rb.inverseMass * rb.constraintCount; 
        rigidbodyAngularW = RotationalInvMass(rb.inverseInertiaTensor,
                                              worldPinOffset - rb.com,
                                              normalizesafe(inertialSolverFrame[0].frame.TransformPoint(particlePosition) - predictedPinOffset)) * rb.constraintCount;

    }

    // Transform pin position to solver space for constraint solving:
    predictedPinOffset = inertialSolverFrame[0].frame.InverseTransformPoint(predictedPinOffset);
    predictedRotation = qmul(q_conj(inertialSolverFrame[0].frame.rotation), predictedRotation);

    // linear constraint: distance between particle and predicted pin point.
    // EPSILON keeps the normalization finite when both coincide.
    float4 gradient = particlePosition - predictedPinOffset;
    float constraint = length(gradient);
    float4 gradientDir = gradient / (constraint + EPSILON);

    // multiplier update: the compliance term uses the value accumulated so far in lambda.w.
    float4 lambda = lambdas[i];
    float linearDLambda = (-constraint - compliances.x * lambda.w) / (invMasses[particleIndex] + rigidbodyLinearW + rigidbodyAngularW + compliances.x + EPSILON);
    lambda.w += linearDLambda;
    float4 correction = linearDLambda * gradientDir;
    
    // particle side: correction weighted by its inverse mass and divided by substepsToEnd.
    AddPositionDelta(particleIndex, correction * invMasses[particleIndex] / substepsToEnd);

    // rigidbody side: opposite correction, divided by frameEnd, applied at the (solver space) pin point.
    if (rigidbodyIndex >= 0)
    {
        ApplyImpulse(rigidbodyIndex,
                     -correction / frameEnd,
                     inertialSolverFrame[0].frame.InverseTransformPoint(worldPinOffset),
                     inertialSolverFrame[0].frame);
    }

    // the angular part is skipped when neither the rigidbody nor the particle can rotate.
    if (rigidbodyAngularW > 0 || invRotationalMasses[particleIndex] > 0)
    {
        // bend/twist constraint:
        quaternion omega = qmul(q_conj(orientations[particleIndex]), predictedRotation);   //darboux vector

        // evaluate both (omega + rest) and (omega - rest) and keep whichever has the smaller squared norm.
        quaternion omega_plus;
        omega_plus = omega + restDarboux[i];        //delta Omega with - omega_0
        omega -= restDarboux[i];                    //delta Omega with + omega_0
        if (dot(omega, omega) > dot(omega_plus, omega_plus))
            omega = omega_plus;

        // angular multiplier update, using the accumulated value in lambda.xyz for the compliance term.
        float3 dlambda = (omega.xyz - compliances.y * lambda.xyz) / (compliances.y + invRotationalMasses[particleIndex] + rigidbodyAngularW + EPSILON);
        lambda.xyz += dlambda;

        //discrete Darboux vector does not have vanishing scalar part:
        quaternion dlambdaQ = quaternion(dlambda[0], dlambda[1], dlambda[2], 0);

        // NOTE(review): this is a plain read-modify-write of orientationDeltasAsInt and a plain increment
        // of orientationConstraintCounts (no Interlocked* intrinsic is used here). If two constraints in
        // the same dispatch can reference the same particle, updates could be lost — confirm.
        quaternion orientDelta = asfloat(orientationDeltasAsInt[particleIndex]);
        orientDelta += qmul(predictedRotation, dlambdaQ) * invRotationalMasses[particleIndex] / substepsToEnd;
        orientationDeltasAsInt[particleIndex] = asuint(orientDelta);
        orientationConstraintCounts[particleIndex]++;

        // rigidbody side: opposite rotation delta, weighted by the rigidbody's angular term.
        if (rigidbodyIndex >= 0)
        {
            ApplyDeltaQuaternion(rigidbodyIndex,
                                 predictedRotation,
                                 -qmul(orientations[particleIndex], dlambdaQ) * rigidbodyAngularW,
                                 inertialSolverFrame[0].frame, stepTime);
        }
    }

    // store the updated multipliers.
    lambdas[i] = lambda;
}

// Apply: one thread per constraint. Calls ApplyPositionDelta and ApplyOrientationDelta for the
// constraint's particle, targeting the writable position/orientation buffers and forwarding sorFactor.
[numthreads(128, 1, 1)]
void Apply (uint3 id : SV_DispatchThreadID)
{
    uint constraintIdx = id.x;

    // threads past the active constraint range have nothing to do.
    if (constraintIdx < activeConstraintCount)
    {
        int pinnedParticle = particleIndices[constraintIdx];

        ApplyPositionDelta(RW_positions, pinnedParticle, sorFactor);
        ApplyOrientationDelta(RW_orientations, pinnedParticle, sorFactor);
    }
}

// ProjectRenderable: one thread per constraint. Overwrites the particle's position with the pin
// offset expressed in solver space, and (only when stiffnesses[i].y is below 10000) its orientation
// with the collider's solver-space rotation composed with restDarboux[i].
[numthreads(128, 1, 1)]
void ProjectRenderable (uint3 id : SV_DispatchThreadID)
{
    uint constraintIdx = id.x;

    // threads past the active constraint range have nothing to do.
    if (constraintIdx >= activeConstraintCount)
        return;

    // no collider to pin to: ignore the constraint.
    int colliderIdx = colliderIndices[constraintIdx];
    if (colliderIdx < 0)
        return;

    // projection deactivated for this constraint (w component below 0.5): ignore it.
    float4 pinOffset = offsets[constraintIdx];
    if (pinOffset.w < 0.5f)
        return;

    int particleIdx = particleIndices[constraintIdx];

    // collider transform expressed relative to the solver's frame.
    transform colliderInSolver = inertialSolverFrame[0].frame.Inverse().Multiply(transforms[colliderIdx]);

    RW_positions[particleIdx] = colliderInSolver.TransformPoint(pinOffset);

    // orientation is only overwritten when the second stiffness component is under the threshold.
    if (stiffnesses[constraintIdx].y < 10000)
    {
        RW_orientations[particleIdx] = qmul(colliderInSolver.rotation, restDarboux[constraintIdx]);
    }
}